Feat: sebulba rec ippo #1142
base: develop
Conversation
Overall the system looks correct and reasonable. Well done Simon! I just left a few minor requests :)
@@ -0,0 +1,910 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
If you can, update the typings in Pipeline in mava/utils/sebulba.py to be Union[PPOTransition, RNNPPOTransition].
This causes errors in the pre-commit. For now I changed both sebulba systems to use the MavaTransition
type-var but this is probably a temporary solution.
Can you please make an issue for this? I think the best solution is to make a protocol with all the common things in a transition (actions, obs, done, reward). The challenge is that named tuples don't seem to work with protocols, so we'd likely need to switch to a flax/chex dataclass.
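For illustration, a rough sketch of what that could look like (class names, field names, and field types here are assumptions for the example, not the PR's actual definitions):

```python
# Sketch only: a Protocol over the fields common to PPO-style transitions,
# with a chex dataclass as a concrete type (since NamedTuples don't seem to
# satisfy Protocols, as noted above).
from typing import Protocol, runtime_checkable

import chex


@runtime_checkable
class TransitionLike(Protocol):
    """Structural type exposing the fields shared by all transitions."""

    action: chex.Array
    obs: chex.ArrayTree
    done: chex.Array
    reward: chex.Array


@chex.dataclass(frozen=True)
class ExamplePPOTransition:
    """Hypothetical transition reworked as a dataclass instead of a NamedTuple."""

    action: chex.Array
    obs: chex.ArrayTree
    done: chex.Array
    reward: chex.Array
    value: chex.Array
    log_prob: chex.Array


def mean_reward(transition: TransitionLike) -> chex.Array:
    # Works for any transition type that exposes the shared fields.
    return transition.reward.mean()
```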
log_prob = actor_policy.log_prob(action)
# It may be faster to calculate the values in the learner as
# then we won't need to pass critic params to actors.
# value = critic_apply_fn(params.critic_params, observation).squeeze()
If you can, remove this comment.
Looks great! Pretty much good to go, except for a few minor style changes to bring it up to date with the latest PPO changes that went in at the end of last year.
def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
    """Update the network for a single minibatch."""
    # UNPACK TRAIN STATE AND BATCH INFO
In line with the PPO updates I did last year:

Suggested change:
- # UNPACK TRAIN STATE AND BATCH INFO
    key: chex.PRNGKey,
) -> Tuple:
    """Calculate the actor loss."""
    # RERUN NETWORK
Suggested change:
- # RERUN NETWORK
+ # Rerun network
    targets: chex.Array,
) -> Tuple:
    """Calculate the critic loss."""
    # RERUN NETWORK
Suggested change:
- # RERUN NETWORK
+ # Rerun network
    critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done
)

# CALCULATE VALUE LOSS
Suggested change:
- # CALCULATE VALUE LOSS
+ # Calculate value loss
loss_actor1 = ratio * gae
loss_actor2 = (
    jnp.clip(
        ratio,
        1.0 - config.system.clip_eps,
        1.0 + config.system.clip_eps,
    )
    * gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
# The seed will be used in the TanhTransformedDistribution:
entropy = actor_policy.entropy(seed=key).mean()

total_loss = loss_actor - config.system.ent_coef * entropy
return total_loss, (loss_actor, entropy)
Suggested change:
- loss_actor1 = ratio * gae
- loss_actor2 = (
-     jnp.clip(
-         ratio,
-         1.0 - config.system.clip_eps,
-         1.0 + config.system.clip_eps,
-     )
-     * gae
- )
- loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
- loss_actor = loss_actor.mean()
- # The seed will be used in the TanhTransformedDistribution:
- entropy = actor_policy.entropy(seed=key).mean()
- total_loss = loss_actor - config.system.ent_coef * entropy
- return total_loss, (loss_actor, entropy)
+ actor_loss1 = ratio * gae
+ actor_loss2 = (
+     jnp.clip(
+         ratio,
+         1.0 - config.system.clip_eps,
+         1.0 + config.system.clip_eps,
+     )
+     * gae
+ )
+ actor_loss = -jnp.minimum(actor_loss1, actor_loss2)
+ actor_loss = actor_loss.mean()
+ # The seed will be used in the TanhTransformedDistribution:
+ entropy = actor_policy.entropy(seed=key).mean()
+ total_loss = actor_loss - config.system.ent_coef * entropy
+ return total_loss, (actor_loss, entropy)
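For reference, both versions compute the same clipped PPO surrogate objective with an entropy bonus; only the variable names differ:

$$
\mathcal{L}(\theta) = -\,\mathbb{E}\big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\big] \;-\; c_{\mathrm{ent}}\,\mathbb{E}\big[\mathcal{H}[\pi_\theta]\big]
$$

where $r_t(\theta)$ is the importance ratio, $\hat{A}_t$ the GAE advantage, $\epsilon$ is config.system.clip_eps, and $c_{\mathrm{ent}}$ is config.system.ent_coef.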
# Calculate critic loss
critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
critic_loss_info, critic_grads = critic_grad_fn(
Suggested change:
- critic_loss_info, critic_grads = critic_grad_fn(
+ value_loss_info, critic_grads = critic_grad_fn(
critic_grads, critic_loss_info = jax.lax.pmean(
    (critic_grads, critic_loss_info), axis_name="learner_devices"
)
Suggested change:
- critic_grads, critic_loss_info = jax.lax.pmean(
-     (critic_grads, critic_loss_info), axis_name="learner_devices"
- )
+ critic_grads, value_loss_info = jax.lax.pmean(
+     (critic_grads, value_loss_info), axis_name="learner_devices"
+ )
actor_total_loss, (actor_loss, entropy) = actor_loss_info
critic_total_loss, (value_loss) = critic_loss_info
total_loss = critic_total_loss + actor_total_loss
loss_info = {
    "total_loss": total_loss,
    "value_loss": value_loss,
    "actor_loss": actor_loss,
    "entropy": entropy,
}
Suggested change:
- actor_total_loss, (actor_loss, entropy) = actor_loss_info
- critic_total_loss, (value_loss) = critic_loss_info
- total_loss = critic_total_loss + actor_total_loss
- loss_info = {
-     "total_loss": total_loss,
-     "value_loss": value_loss,
-     "actor_loss": actor_loss,
-     "entropy": entropy,
- }
+ actor_loss, (_, entropy) = actor_loss_info
+ value_loss, (unscaled_value_loss) = value_loss_info
+ total_loss = actor_loss + value_loss
+ loss_info = {
+     "total_loss": total_loss,
+     "value_loss": unscaled_value_loss,
+     "actor_loss": actor_loss,
+     "entropy": entropy,
+ }
batch = tree.map(
    lambda x: x.reshape(
        config.system.recurrent_chunk_size,
        num_learner_envs * num_recurrent_chunks,
Can you extract this into a variable called batch_size? It will make it clearer, and it's also used below.
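For example, a minimal sketch of the suggested extraction (the trailing reshape dimensions and the `traj_batch` name are assumptions based on the surrounding diff, not taken verbatim from the PR):

```python
# Hypothetical refactor: name the flattened batch dimension once and reuse it below.
batch_size = num_learner_envs * num_recurrent_chunks
batch = tree.map(
    lambda x: x.reshape(
        config.system.recurrent_chunk_size,
        batch_size,
        *x.shape[2:],  # assumption: keep any remaining dimensions unchanged
    ),
    traj_batch,  # assumed name of the trajectory batch being reshaped
)
```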
# batch_size = config.system.rollout_length * num_learner_envs
# permutation = jax.random.permutation(shuffle_key, batch_size)
Suggested change:
- # batch_size = config.system.rollout_length * num_learner_envs
- # permutation = jax.random.permutation(shuffle_key, batch_size)
Sebulba implementation of recurrent IPPO.